#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author ：hhx
@Date ：2022/5/21 21:53 
@Description ：卷积对抗自编码器 (convolutional adversarial autoencoder)
"""

import argparse
import torch
import pickle
import numpy as np
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Size of the latent code: the output width of Q_net's final Linear layer
# and the input width of P_net's first Linear layer and of D_net_gauss.
z_dim = 64


##################################
# Define Networks
##################################
# Encoder
class Q_net(nn.Module):
    """Convolutional encoder mapping an image batch to a latent code.

    Two unpadded 3x3 convolutions each shrink the spatial size by 2 pixels,
    so a 32 x 32 input becomes a 16 x 28 x 28 feature map. That map is
    flattened per sample and projected linearly (no activation) to
    ``latent_dim`` values.

    Args:
        in_channels: Number of input channels (default 10, as originally
            hard-coded).
        latent_dim: Size of the produced latent code. When ``None`` the
            module-level ``z_dim`` is used.

    The activations of the most recent forward pass are kept on the instance
    as ``x1``, ``x2`` and ``x3``. They stay referenced until the next call.
    """

    def __init__(self, *, in_channels=10, latent_dim=None):
        super(Q_net, self).__init__()
        if latent_dim is None:
            latent_dim = z_dim
        self.enc1 = nn.Conv2d(in_channels=in_channels, out_channels=8, kernel_size=3)
        self.enc2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)
        # 28 * 28 * 16 features <=> 32 x 32 inputs (two 3x3 convs, no padding).
        self.enc3 = nn.Linear(in_features=28 * 28 * 16, out_features=latent_dim)

    def forward(self, x):
        """Encode ``x`` of shape (N, in_channels, 32, 32) into (N, latent_dim).

        An unbatched (in_channels, 32, 32) input is treated as a batch of one.
        """
        self.x1 = F.relu(self.enc1(x))
        h = F.relu(self.enc2(self.x1))
        if h.dim() == 3:
            # Unbatched conv output: add the batch axis back so the result is
            # (1, latent_dim), matching the previous view(-1, ...) behaviour.
            h = h.unsqueeze(0)
        # Flatten per sample rather than with view(-1, ...): an input of the
        # wrong spatial size now fails in enc3 instead of being silently
        # re-split into a different number of samples.
        self.x2 = h.flatten(start_dim=1)
        self.x3 = self.enc3(self.x2)
        return self.x3


# Decoder
class P_net(nn.Module):
    """Convolutional decoder mapping a latent code back to an image batch.

    The code is projected by a Linear layer to a 16 x 28 x 28 feature map. Two
    3x3 transposed convolutions (no padding) then grow it by 2 pixels each,
    to 32 x 32. A final sigmoid squashes every output value into (0, 1).

    Args:
        latent_dim: Size of the incoming latent code. When ``None`` the
            module-level ``z_dim`` is used.
        out_channels: Number of output image channels (default 10, as
            originally hard-coded).

    The activations of the most recent forward pass are kept on the instance
    as ``x4`` and ``x5``.
    """

    def __init__(self, *, latent_dim=None, out_channels=10):
        super(P_net, self).__init__()
        if latent_dim is None:
            latent_dim = z_dim
        self.dec1 = nn.Linear(in_features=latent_dim, out_features=28 * 28 * 16)
        self.dec2 = nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=3)
        self.dec3 = nn.ConvTranspose2d(in_channels=8, out_channels=out_channels, kernel_size=3)

    def forward(self, x):
        """Decode ``x`` of shape (N, latent_dim) into (N, out_channels, 32, 32)."""
        # NOTE(review): ReLU is applied to the latent code itself, so every
        # negative coordinate decodes exactly like 0. If codes are pushed
        # towards a zero-mean prior (see the Gaussian discriminator), this
        # throws away half of the latent range -- confirm it is intended.
        self.x4 = F.relu(self.dec1(F.relu(x))).reshape(-1, 16, 28, 28)
        self.x5 = F.relu(self.dec2(self.x4))
        return torch.sigmoid(self.dec3(self.x5))


class D_net_gauss(nn.Module):
    """MLP discriminator that scores latent codes.

    Architecture: ``latent_dim -> 128 -> 256 -> 1``. Each hidden layer is
    followed by dropout (active only in training mode) and a ReLU. The
    final score goes through a sigmoid, so it lies in (0, 1). Presumably real
    samples come from a Gaussian prior (per the class name) -- confirm
    against the training code.

    Args:
        latent_dim: Size of the latent codes being scored. When ``None`` the
            module-level ``z_dim`` is used.
        dropout_p: Dropout probability of the hidden layers (default 0.2, as
            originally hard-coded).

    Raises:
        ValueError: if ``dropout_p`` is outside [0, 1].
    """

    # Class-level fallback so instances pickled before ``dropout_p`` became
    # configurable still find the original value.
    dropout_p = 0.2

    def __init__(self, *, latent_dim=None, dropout_p=0.2):
        super(D_net_gauss, self).__init__()
        if latent_dim is None:
            latent_dim = z_dim
        if not 0.0 <= dropout_p <= 1.0:
            raise ValueError(f"dropout_p must be in [0, 1], got {dropout_p}")
        self.dropout_p = dropout_p
        self.lin1 = nn.Linear(latent_dim, 128)
        self.lin2 = nn.Linear(128, 256)
        self.lin3 = nn.Linear(256, 1)

    def forward(self, x):
        """Score ``x`` of shape (N, latent_dim); returns (N, 1) values in (0, 1)."""
        x = F.dropout(self.lin1(x), p=self.dropout_p, training=self.training)
        x = F.relu(x)
        x = F.dropout(self.lin2(x), p=self.dropout_p, training=self.training)
        x = F.relu(x)
        return torch.sigmoid(self.lin3(x))
